"""
Preprocessing Script for Structured3D

Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com)
Please cite our work if the code is helpful to you.
"""

import argparse
import io
import multiprocessing as mp
import os
import zipfile
from concurrent.futures import ProcessPoolExecutor
from itertools import repeat

import cv2
import numpy as np
import PIL
import torch
from PIL import Image

from ponder.datasets.transform import GridSample

# Values of the semantic map that are kept. The remapping loops in this file
# enumerate this tuple, so the position of an id is the label written to disk;
# every value not listed here becomes `ignore_index`.
# fmt: off
VALID_CLASS_IDS_25 = (
    1, 2, 3, 4, 5,
    6, 7, 8, 9, 11,
    14, 15, 16, 17, 18,
    19, 22, 24, 25, 32,
    34, 35, 38, 39, 40,
)
# fmt: on
# Human-readable class names, 25 entries.
# NOTE(review): presumably parallel to VALID_CLASS_IDS_25 (same length, same
# order) - confirm against the dataset documentation.
# fmt: off
CLASS_LABELS_25 = (
    "wall", "floor", "cabinet", "bed", "chair",
    "sofa", "table", "door", "window", "picture",
    "desk", "shelves", "curtain", "dresser", "pillow",
    "mirror", "ceiling", "refrigerator", "television", "nightstand",
    "sink", "lamp", "otherstructure", "otherfurniture", "otherprop",
)
# fmt: on


def normal_from_cross_product(points_2d: np.ndarray) -> np.ndarray:
    """Estimate a unit normal for every pixel of an organized point map.

    Each pixel's normal is the normalized cross product of the difference to
    its neighbour one row further and the difference to its neighbour one
    column further.

    Args:
        points_2d: 3-D floating-point array; the first two axes form the pixel
            grid and the last axis holds the point coordinates.

    Returns:
        Array shaped like ``points_2d``. Pixels whose cross product has zero
        length (including the last row and last column, which have no
        following neighbour) get an all-zero normal.
    """
    # Append one mirrored row and column: the extra neighbour of a border
    # pixel is the pixel itself, so its difference is exactly zero.
    padded = np.pad(points_2d, ((0, 1), (0, 1), (0, 0)), mode="symmetric")
    anchor = padded[:-1, :-1, :]
    delta_next_col = anchor - padded[:-1, 1:, :]
    delta_next_row = anchor - padded[1:, :-1, :]

    raw_normal = np.cross(delta_next_row, delta_next_col)
    length = np.linalg.norm(raw_normal, axis=-1, keepdims=True)

    # Divide only where the length is non-zero; everything else stays 0.
    unit_normal = np.zeros_like(raw_normal)
    np.divide(raw_normal, length, out=unit_normal, where=length != 0)
    return unit_normal


class Structured3DReader:
    """Read files directly out of one or more zip archives.

    All archives are opened on construction and every member name is mapped to
    the archive that holds it. If the same member name occurs in several
    archives, the archive listed last wins.

    The reader owns the open archives: call :meth:`close` when done, or use
    the instance as a context manager.
    """

    def __init__(self, files):
        """Open the archives.

        Args:
            files: A single path (``str`` or ``os.PathLike``) or an iterable
                of paths to zip archives.
        """
        super().__init__()
        if isinstance(files, (str, os.PathLike)):
            files = [files]
        self.readers = []
        self.names_mapper = dict()
        try:
            for idx, f in enumerate(files):
                reader = zipfile.ZipFile(f, "r")
                self.readers.append(reader)
                for name in reader.namelist():
                    self.names_mapper[name] = idx
        except BaseException:
            # Do not leak the archives opened before the failing one.
            self.close()
            raise

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def close(self):
        """Close every underlying archive; safe to call more than once."""
        for reader in self.readers:
            reader.close()

    @staticmethod
    def _zip_name(path):
        """Convert an OS-style path to the "/"-separated form used in zips."""
        if os.path.sep != "/":
            path = path.replace(os.path.sep, "/")
        return path

    def filelist(self):
        """Return every member name known across all archives."""
        return list(self.names_mapper.keys())

    def listdir(self, dir_name):
        """List the immediate children of ``dir_name`` inside the archives.

        Args:
            dir_name: Directory path inside the archives; leading and trailing
                separators are ignored.

        Returns:
            Sorted list of unique child names (files and sub-directories).
        """
        # Zip member names always use "/", whatever the host OS separator is.
        prefix = self._zip_name(dir_name).strip("/") + "/"
        children = {
            name[len(prefix):].split("/", 1)[0]
            for name in self.names_mapper
            if name.startswith(prefix)
        }
        # The directory's own entry ("dir/") yields an empty child name.
        children.discard("")
        return sorted(children)

    def read(self, file_name):
        """Return the raw bytes of ``file_name``.

        Raises:
            KeyError: If no archive contains ``file_name``.
        """
        if file_name not in self.names_mapper:
            # Accept names built with os.path.join on non-POSIX hosts.
            file_name = self._zip_name(file_name)
        split = self.names_mapper[file_name]
        return self.readers[split].read(file_name)

    def read_camera(self, camera_path):
        """Parse a whitespace-separated camera text file.

        The first three numbers are the camera position; when more numbers are
        present, values 3-5 are read as the front vector, 6-8 as the up vector
        and 9-10 as ``cam_f``.

        Returns:
            Tuple ``(cam_r, cam_t, cam_f)``: a 3x3 float32 rotation (identity
            when the file holds only a position), the position divided by 1000
            with its axes permuted, and ``cam_f`` (``None`` when the file
            holds only a position).
        """
        # Permutes (x, y, z) -> (y, z, x).
        z2y_top_m = np.array([[0, 1, 0], [0, 0, 1], [1, 0, 0]], dtype=np.float32)
        cam_extr = np.fromstring(self.read(camera_path), dtype=np.float32, sep=" ")
        # NOTE(review): the /1000 presumably converts millimetres to metres.
        cam_t = np.matmul(z2y_top_m, cam_extr[:3] / 1000)
        if cam_extr.shape[0] > 3:
            cam_front, cam_up = cam_extr[3:6], cam_extr[6:9]
            cam_n = np.cross(cam_front, cam_up)
            # Columns of cam_r are the front, up and (front x up) vectors.
            cam_r = np.stack((cam_front, cam_up, cam_n), axis=1).astype(np.float32)
            cam_r = np.matmul(z2y_top_m, cam_r)
            cam_f = cam_extr[9:11]
        else:
            cam_r = np.eye(3, dtype=np.float32)
            cam_f = None
        return cam_r, cam_t, cam_f

    def read_depth(self, depth_path):
        """Decode a depth image and return it with a trailing axis of size 1.

        Pixels equal to 0 are rewritten to 65535 so that one sentinel value
        marks "no depth".
        """
        depth = cv2.imdecode(
            np.frombuffer(self.read(depth_path), np.uint8), cv2.IMREAD_UNCHANGED
        )[..., np.newaxis]
        depth[depth == 0] = 65535
        return depth

    def read_color(self, color_path):
        """Decode a color image, keep its first three channels and reverse them.

        OpenCV decodes color images in BGR order, so the reversal yields RGB.
        """
        color = cv2.imdecode(
            np.frombuffer(self.read(color_path), np.uint8), cv2.IMREAD_UNCHANGED
        )[..., :3][..., ::-1]
        return color

    def read_segment(self, segment_path):
        """Load a semantic image as an array with one extra trailing axis."""
        segment = np.array(PIL.Image.open(io.BytesIO(self.read(segment_path))))[
            ..., np.newaxis
        ]
        return segment


def _valid_point_mask(coord, normal, depth, segment):
    """Return a flat boolean mask selecting the usable pixels of one view.

    A pixel is kept when the absolute cosine between its viewing ray and its
    normal exceeds 0.15 (drops grazing-angle surfaces and zero normals), its
    depth is below the 65535 "no depth" sentinel, and its semantic id is > 0.
    """
    view_dist = np.maximum(
        np.linalg.norm(coord, axis=-1, keepdims=True), float(10e-5)
    )
    cosine_dist = np.abs(
        np.sum((coord * normal / view_dist), axis=-1, keepdims=True)
    )
    return ((cosine_dist > 0.15) & (depth < 65535) & (segment > 0))[
        ..., 0
    ].reshape(-1)


def _remap_to_class_index(segment, ignore_index):
    """Map raw semantic ids to their position in ``VALID_CLASS_IDS_25``.

    Returns an int64 array shaped like ``segment`` without its last axis; ids
    that are not listed become ``ignore_index``. ``np.full`` casts the fill
    value, so the result stays int64 even if ``ignore_index`` is a float.
    """
    remapped = np.full(segment.shape[:-1], ignore_index, dtype=np.int64)
    for idx, value in enumerate(VALID_CLASS_IDS_25):
        remapped[np.all(segment == value, axis=-1)] = idx
    return remapped


def _parse_room(
    reader,
    scene,
    room,
    scene_output_path,
    ignore_index,
    grid_size,
    fuse_prsp,
    fuse_pano,
    parse_rgbd,
    plugin_rgbd,
    vis,
):
    """Process one room: export RGB-D frames and/or the fused point cloud."""
    room_path = os.path.join("Structured3D", scene, "2D_rendering", room)
    coord_list = list()
    color_list = list()
    normal_list = list()
    segment_list = list()

    if fuse_prsp or parse_rgbd or plugin_rgbd:
        prsp_path = os.path.join(room_path, "perspective", "full")
        for frame in reader.listdir(prsp_path):
            try:
                cam_r, cam_t, cam_f = reader.read_camera(
                    os.path.join(prsp_path, frame, "camera_pose.txt")
                )
                depth = reader.read_depth(
                    os.path.join(prsp_path, frame, "depth.png")
                )
                color = reader.read_color(
                    os.path.join(prsp_path, frame, "rgb_rawlight.png")
                )
                segment = reader.read_segment(
                    os.path.join(prsp_path, frame, "semantic.png")
                )
            except Exception:
                # Best effort: a missing or corrupt frame is skipped, but
                # KeyboardInterrupt / SystemExit are no longer swallowed.
                print(
                    f"Skipping {scene}_room{room}_frame{frame} perspective view due to loading error"
                )
                continue

            fx, fy = cam_f
            height, width = depth.shape[0], depth.shape[1]
            # Homogeneous pixel coordinates (x, y, 1), one row per pixel in
            # row-major image order.
            pixel = np.transpose(np.indices((width, height)), (2, 1, 0))
            pixel = pixel.reshape((-1, 2))
            pixel = np.hstack((pixel, np.ones((pixel.shape[0], 1))))

            # Pinhole intrinsics: principal point at the image centre; the
            # cam_f values are treated as angles, focal = half-size / tan(angle).
            k = np.diag([1.0, 1.0, 1.0])
            k[0, 2] = width / 2
            k[1, 2] = height / 2
            k[0, 0] = k[0, 2] / np.tan(fx)
            k[1, 1] = k[1, 2] / np.tan(fy)

            # Back-project every pixel, then map (x, y, z) -> (z, -y, x).
            coord = (
                depth.reshape((-1, 1)) * (np.linalg.inv(k) @ pixel.T).T
            ).reshape(height, width, 3)
            coord = coord @ np.array([[0, 0, 1], [0, -1, 0], [1, 0, 0]])
            normal = normal_from_cross_product(coord)

            # Filtering invalid points (done in the camera frame)
            mask = _valid_point_mask(coord, normal, depth, segment)
            if not mask.any():
                print(
                    f"Skipping {scene}_room{room}_frame{frame} perspective view due to all points are filtered out"
                )
                continue

            if fuse_prsp:
                # Move the points into the scene frame and recompute normals.
                coord = np.matmul(coord / 1000, cam_r.T) + cam_t
                normal = normal_from_cross_product(coord)
                coord_list.append(coord.reshape(-1, 3)[mask])
                color_list.append(color.reshape(-1, 3)[mask])
                normal_list.append(normal.reshape(-1, 3)[mask])
                segment_list.append(segment.reshape(-1, 1)[mask])

            if parse_rgbd or plugin_rgbd:
                depth = depth[:, :, 0]
                extrinsic = np.eye(4)
                extrinsic[:3, :3] = cam_r
                extrinsic[:3, 3] = cam_t
                extrinsic = np.array(
                    [[1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]]
                ) @ np.linalg.inv(
                    np.array(
                        [
                            [0, 0, 1, 0],
                            [0, -1, 0, 0],
                            [1, 0, 0, 0],
                            [0, 0, 0, 1],
                        ]
                    )
                    @ np.linalg.inv(extrinsic)
                )
                semantic_map = _remap_to_class_index(segment, ignore_index)

                rgbd_dict = dict(
                    intrinsic=k,
                    extrinsic=extrinsic,
                    rgb=color,
                    depth=depth,
                    depth_mask=mask.reshape(depth.shape[0], depth.shape[1]),
                    semantic_map=semantic_map,
                )
                room_output_path = os.path.join(
                    scene_output_path, f"room_{room}_rgbd/"
                )
                os.makedirs(room_output_path, exist_ok=True)
                torch.save(
                    rgbd_dict,
                    os.path.join(room_output_path, f"frame_{frame}.pth"),
                )

    if plugin_rgbd:  # in plugin mode, we only extract RGB-D images
        return

    if fuse_pano:
        pano_path = os.path.join(room_path, "panorama")
        try:
            _, cam_t, _ = reader.read_camera(
                os.path.join(pano_path, "camera_xyz.txt")
            )
            depth = reader.read_depth(os.path.join(pano_path, "full", "depth.png"))
            color = reader.read_color(
                os.path.join(pano_path, "full", "rgb_rawlight.png")
            )
            segment = reader.read_segment(
                os.path.join(pano_path, "full", "semantic.png")
            )
        except Exception:
            print(f"Skipping {scene}_room{room} panorama view due to loading error")
        else:
            # Per-pixel angles: p_a spans [-pi, pi) across the columns and
            # p_b runs from pi/2 downwards across the rows.
            p_h, p_w = depth.shape[:2]
            p_a = np.arange(p_w, dtype=np.float32) / p_w * 2 * np.pi - np.pi
            p_b = np.arange(p_h, dtype=np.float32) / p_h * np.pi * -1 + np.pi / 2
            p_a = np.tile(p_a[None], [p_h, 1])[..., np.newaxis]
            p_b = np.tile(p_b[:, None], [1, p_w])[..., np.newaxis]
            p_a_sin, p_a_cos, p_b_sin, p_b_cos = (
                np.sin(p_a),
                np.cos(p_a),
                np.sin(p_b),
                np.cos(p_b),
            )
            x = depth * p_a_cos * p_b_cos
            y = depth * p_b_sin
            z = depth * p_a_sin * p_b_cos
            coord = np.concatenate([x, y, z], axis=-1) / 1000
            normal = normal_from_cross_product(coord)

            # Filtering invalid points (before shifting by the camera position)
            mask = _valid_point_mask(coord, normal, depth, segment)
            coord = coord + cam_t

            if mask.any():
                coord_list.append(coord.reshape(-1, 3)[mask])
                color_list.append(color.reshape(-1, 3)[mask])
                normal_list.append(normal.reshape(-1, 3)[mask])
                segment_list.append(segment.reshape(-1, 1)[mask])
            else:
                print(
                    f"Skipping {scene}_room{room} panorama view due to all points are filtered out"
                )

    if not coord_list:
        print(f"Skipping {scene}_room{room} due to no valid points")
        return

    # Swap the second and third axes of the fused coordinates and normals.
    swap_yz = np.array([[1, 0, 0], [0, 0, 1], [0, 1, 0]])
    coord = np.concatenate(coord_list, axis=0) @ swap_yz
    color = np.concatenate(color_list, axis=0)
    normal = np.concatenate(normal_list, axis=0) @ swap_yz
    segment = np.concatenate(segment_list, axis=0)
    # Keep the (N, 1) layout of the stored labels.
    segment25 = _remap_to_class_index(segment, ignore_index)[:, np.newaxis]

    data_dict = dict(
        coord=coord.astype("float32"),
        color=color.astype("uint8"),
        normal=normal.astype("float32"),
        semantic_gt=segment25.astype("int16"),
    )
    # Grid sampling data
    if grid_size is not None:
        sampler = GridSample(
            grid_size=grid_size,
            keys=("coord", "color", "normal", "semantic_gt"),
        )
        data_dict = sampler(data_dict)
    torch.save(data_dict, os.path.join(scene_output_path, f"room_{room}.pth"))

    if vis:
        from ponder.utils.visualization import save_point_cloud

        os.makedirs("./vis", exist_ok=True)
        save_point_cloud(coord, color / 255, f"./vis/{scene}_room{room}_color.ply")
        save_point_cloud(
            coord, (normal + 1) / 2, f"./vis/{scene}_room{room}_normal.ply"
        )


def parse_scene(
    scene,
    dataset_root,
    output_root,
    ignore_index=-1,
    grid_size=None,
    fuse_prsp=True,
    fuse_pano=True,
    parse_rgbd=False,
    plugin_rgbd=False,
    vis=False,
):
    """Convert every room of one scene into ``.pth`` files.

    The split is derived from the numeric suffix of the scene name: ids below
    3000 go to "train", 3000-3249 to "val" and everything else to "test".
    For each room, ``<output_root>/<split>/<scene>/room_<room>.pth`` holds the
    fused point cloud (``coord``, ``color``, ``normal``, ``semantic_gt``) and,
    in the RGB-D modes, ``room_<room>_rgbd/frame_<frame>.pth`` holds one dict
    per perspective frame.

    Args:
        scene: Scene directory name inside the archives; its part after the
            last "_" must be an integer.
        dataset_root: Directory containing the dataset ``.zip`` archives.
        output_root: Directory under which the split folders are written.
        ignore_index: Label given to semantic ids outside VALID_CLASS_IDS_25.
        grid_size: If not None, the fused cloud is passed through GridSample
            with this grid size before saving.
        fuse_prsp: Fuse perspective views into the room point cloud.
        fuse_pano: Fuse the panorama view into the room point cloud.
        parse_rgbd: Additionally export per-frame RGB-D dictionaries.
        plugin_rgbd: Export per-frame RGB-D dictionaries only; no fused cloud.
        vis: Also write colour and normal ``.ply`` files to ``./vis``.
    """
    assert fuse_prsp or fuse_pano or parse_rgbd or plugin_rgbd
    reader = Structured3DReader(
        [
            os.path.join(dataset_root, f)
            for f in os.listdir(dataset_root)
            if f.endswith(".zip")
        ]
    )
    try:
        scene_id = int(os.path.basename(scene).split("_")[-1])
        if scene_id < 3000:
            split = "train"
        elif scene_id < 3250:
            split = "val"
        else:
            split = "test"

        print(f"Processing: {scene} in {split}")
        scene_output_path = os.path.join(
            output_root, split, os.path.basename(scene)
        )
        os.makedirs(scene_output_path, exist_ok=True)
        rooms = reader.listdir(os.path.join("Structured3D", scene, "2D_rendering"))
        for room in rooms:
            _parse_room(
                reader,
                scene,
                room,
                scene_output_path,
                ignore_index,
                grid_size,
                fuse_prsp,
                fuse_pano,
                parse_rgbd,
                plugin_rgbd,
                vis,
            )
    finally:
        # Release the zip handles opened for this scene.
        for zip_reader in reader.readers:
            zip_reader.close()


if __name__ == "__main__":
    # Command-line entry point: list the scenes found in the dataset archives
    # and preprocess them in parallel, one scene per task.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dataset_root",
        required=True,
        help="Path to the Structured3D dataset folder containing the zip archives.",
    )
    parser.add_argument(
        "--output_root",
        required=True,
        help="Output path where train/val/test folders will be located.",
    )
    parser.add_argument(
        "--num_workers",
        default=mp.cpu_count(),
        type=int,
        help="Num workers for preprocessing.",
    )
    parser.add_argument(
        "--grid_size", default=None, type=float, help="Grid size for grid sampling."
    )
    # Labels are integers; parsing this as float would leak a float into the
    # integer label arrays built during preprocessing.
    parser.add_argument("--ignore_index", default=-1, type=int, help="Ignore index.")
    parser.add_argument(
        "--fuse_prsp", action="store_true", help="Whether to fuse perspective view."
    )
    parser.add_argument(
        "--fuse_pano", action="store_true", help="Whether to fuse panorama view."
    )
    parser.add_argument(
        "--parse_rgbd", action="store_true", help="Whether to parse RGB-D images."
    )
    parser.add_argument(
        "--plugin_rgbd",
        action="store_true",
        help="Whether to parse RGB-D images only as a plugin mode.",
    )
    config = parser.parse_args()

    # parse_scene asserts that at least one mode is enabled; fail here with a
    # readable message instead of an AssertionError raised from a worker.
    if not (
        config.fuse_prsp
        or config.fuse_pano
        or config.parse_rgbd
        or config.plugin_rgbd
    ):
        parser.error(
            "enable at least one of --fuse_prsp, --fuse_pano, --parse_rgbd, "
            "--plugin_rgbd"
        )

    reader = Structured3DReader(
        [
            os.path.join(config.dataset_root, f)
            for f in os.listdir(config.dataset_root)
            if f.endswith(".zip")
        ]
    )
    try:
        scenes_list = sorted(reader.listdir("Structured3D"))
    finally:
        # Only the scene names are needed here; every worker opens its own
        # reader, so release these handles before the pool starts.
        for zip_reader in reader.readers:
            zip_reader.close()

    for split_name in ("train", "val", "test"):
        os.makedirs(os.path.join(config.output_root, split_name), exist_ok=True)

    # Preprocess data.
    print("Processing scenes...")
    # The context manager shuts the pool down even if a worker raises.
    with ProcessPoolExecutor(max_workers=config.num_workers) as pool:
        _ = list(
            pool.map(
                parse_scene,
                scenes_list,
                repeat(config.dataset_root),
                repeat(config.output_root),
                repeat(config.ignore_index),
                repeat(config.grid_size),
                repeat(config.fuse_prsp),
                repeat(config.fuse_pano),
                repeat(config.parse_rgbd),
                repeat(config.plugin_rgbd),
            )
        )
